{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch \n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import sklearn as sk\n",
    "from sklearn import preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_0_F.shape:  (700000, 15)\n",
      "train_1_F.shape:  (42000, 15)\n"
     ]
    }
   ],
   "source": [
    "train_0_F = np.loadtxt(open(\"../../data/motor_fault/train_0_F.csv\",\"rb\"), delimiter=\",\", skiprows=0)\n",
    "train_1_F = np.loadtxt(open(\"../../data/motor_fault/train_1_F.csv\",\"rb\"), delimiter=\",\", skiprows=0)\n",
    "print(\"train_0_F.shape: \", train_0_F.shape)\n",
    "print(\"train_1_F.shape: \", train_1_F.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_0_F[:10, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 高斯归一化(注意，归一化不应该对label（最后一列）进行操作)\n",
    "train_0_F[:,:-1] = preprocessing.scale(train_0_F[:,:-1])\n",
    "train_1_F[:,:-1] = preprocessing.scale(train_1_F[:,:-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_0_F[:10, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device configuration\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "input_size = 14\n",
    "hidden_size1 = 100\n",
    "hidden_size2 = 60\n",
    "hidden_size3 = 30\n",
    "num_classes = 2\n",
    "num_epochs = 10\n",
    "t_max = 10\n",
    "batch_size = 100\n",
    "learning_rate = 1e-2\n",
    "valid_size = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x.size: torch.Size([742000, 14])\n",
      "y.size: torch.Size([742000])\n"
     ]
    }
   ],
   "source": [
    "# Dataset\n",
    "train_F = np.concatenate((train_0_F, train_1_F), axis=0)\n",
    "x = torch.tensor(train_F[:,:-1], dtype=torch.float32)\n",
    "y = torch.tensor(train_F[:,-1], dtype=torch.long)\n",
    "print(\"x.size: {}\".format(x.size()))\n",
    "print(\"y.size: {}\".format(y.size()))\n",
    "dataset = torch.utils.data.TensorDataset(x, y)\n",
    "\n",
    "num_train = len(dataset)\n",
    "indices = list(range(num_train))\n",
    "split = int(np.floor(valid_size * num_train))\n",
    "np.random.shuffle(indices)\n",
    "train_idx, test_idx = indices[:split], indices[split:]\n",
    "trainset = torch.utils.data.Subset(dataset=dataset, indices=train_idx)\n",
    "valset = torch.utils.data.Subset(dataset=dataset, indices=test_idx)\n",
    "\n",
    "trainloader = torch.utils.data.DataLoader(trainset,\n",
    "                                     batch_size = batch_size,\n",
    "                                     shuffle = True,\n",
    "                                     num_workers = 2)\n",
    "valloader = torch.utils.data.DataLoader(valset,\n",
    "                                     batch_size = batch_size,\n",
    "                                     shuffle = False,\n",
    "                                     num_workers = 2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.31861003, -0.03385312, -0.03910605, -0.18806567,  0.22277755,\n",
       "        -0.54947453,  0.020132  ,  0.56125054,  0.19723481,  0.5673156 ,\n",
       "         0.81437145, -0.70651987,  0.68691267, -0.56649781,  0.        ],\n",
       "       [-0.31861003, -0.0337537 , -0.03915167, -0.18803359,  0.22290679,\n",
       "        -0.54939087,  0.02017467,  0.56125054,  0.19638423,  0.56700683,\n",
       "         0.81619242, -0.7033982 ,  0.69121001, -0.56619137,  0.        ],\n",
       "       [-0.31861003, -0.03392825, -0.03912249, -0.18808508,  0.22285648,\n",
       "        -0.54953776,  0.02014738,  0.56125054,  0.19545258,  0.56783042,\n",
       "         0.81938876, -0.70054403,  0.69592435, -0.56700874,  0.        ],\n",
       "       [-0.31861003, -0.03369348, -0.03921473, -0.18801082,  0.22340346,\n",
       "        -0.5493402 ,  0.02023367,  0.56125054,  0.19480136,  0.56556227,\n",
       "         0.81869033, -0.6971952 ,  0.69922429, -0.56475771,  0.        ],\n",
       "       [-0.31861003, -0.0341658 , -0.03940782, -0.18812439,  0.22470546,\n",
       "        -0.54973775,  0.0204143 ,  0.56125054,  0.19344178,  0.56607866,\n",
       "         0.82265764, -0.69373627,  0.70612598, -0.5652702 ,  0.        ],\n",
       "       [-0.31861003, -0.03371225, -0.03978293, -0.18796226,  0.22742175,\n",
       "        -0.54935599,  0.0207652 ,  0.56125054,  0.19117528,  0.56286443,\n",
       "         0.82507231, -0.68598209,  0.71766869, -0.56208021,  0.        ],\n",
       "       [-0.31861003, -0.03380684, -0.03986823, -0.18798058,  0.22788559,\n",
       "        -0.54943559,  0.02084499,  0.56125054,  0.18898288,  0.55609695,\n",
       "         0.82362109, -0.67856348,  0.72887839, -0.55536379,  0.        ],\n",
       "       [-0.31861003, -0.03425537, -0.04038675, -0.18805668,  0.23063106,\n",
       "        -0.54981316,  0.02133005,  0.56125054,  0.18691363,  0.5487965 ,\n",
       "         0.82129687, -0.67112573,  0.73949869, -0.54811841,  0.        ],\n",
       "       [-0.31861003, -0.03410679, -0.04068756, -0.1879867 ,  0.23285258,\n",
       "        -0.54968806,  0.02161144,  0.56125054,  0.18598485,  0.54514421,\n",
       "         0.81986017, -0.66778517,  0.74427836, -0.54449368,  0.        ],\n",
       "       [-0.31861003, -0.03365657, -0.0406108 , -0.18786833,  0.23262337,\n",
       "        -0.54930914,  0.02153964,  0.56125054,  0.18593857,  0.54108954,\n",
       "         0.81575894, -0.66567619,  0.74451677, -0.5404696 ,  0.        ]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_0_F[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(42000)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fully connected neural network with one hidden layer\n",
    "class NeuralNet(nn.Module):\n",
    "    def __init__(self, input_size=input_size, hidden_size1=hidden_size1, \n",
    "                 hidden_size2=hidden_size2, hidden_size3=hidden_size3, num_classes=num_classes):\n",
    "        super(NeuralNet, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size1)\n",
    "        self.relu = nn.LeakyReLU()\n",
    "        self.fc2 = nn.Linear(hidden_size1, hidden_size2)\n",
    "        self.fc3 = nn.Linear(hidden_size2, hidden_size3)\n",
    "        self.fc4 = nn.Linear(hidden_size3, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.fc1(x)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc3(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc4(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = NeuralNet().to(device)\n",
    "weight=torch.tensor([30,500]).to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)\n",
    "lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = t_max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(loader):\n",
    "    # test on all test data\n",
    "    correct = 0 \n",
    "    total = 0\n",
    "    total0 = 0\n",
    "    total1 = 0\n",
    "    p00 = 0 #第一个0代表label，第二个0代表预测值\n",
    "    p01 = 0\n",
    "    p10 = 0\n",
    "    p11 = 0    \n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            outputs = model(x)\n",
    "            predicted = torch.max(outputs.data, 1)[1]\n",
    "            total += y.size(0)\n",
    "            total1 += sum(y.cpu().numpy())\n",
    "            correct += (predicted == y).sum().item()\n",
    "            for i in range(batch_size):\n",
    "                if y[i] == 0 and predicted[i] == 0:\n",
    "                    p00 += 1\n",
    "                elif y[i] == 0 and predicted[i] == 1:\n",
    "                    p01 += 1\n",
    "                elif y[i] == 1 and predicted[i] == 0:\n",
    "                    p10 += 1\n",
    "                elif y[i] == 1 and predicted[i] == 1:\n",
    "                    p11 += 1\n",
    "                \n",
    "    total0 = total - total1\n",
    "    my_acc = correct / total\n",
    "    my_recall = p11 / total1\n",
    "    my_precision = p00 / total0 #大赛的要求，和课本上的精度定义不一样\n",
    "    \n",
    "    model.train()\n",
    "    #print('Test_acc: {:.2f}%'.format(100 * correct / total))\n",
    "    return my_acc, my_recall, my_precision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1-3/(50+3) = 0.943"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_acc: 0.0570, train_recall: 1.0000, train_precision: 0.0000\n",
      "val_acc: 0.0565, val_recall: 1.0000, val_precision: 0.0000\n"
     ]
    }
   ],
   "source": [
    "train_acc, train_recall, train_precision = test(trainloader)\n",
    "print(\"train_acc: {:.4f}, train_recall: {:.4f}, train_precision: {:.4f}\".format(train_acc, train_recall, train_precision))\n",
    "val_acc, val_recall, val_precision = test(valloader)\n",
    "print(\"val_acc: {:.4f}, val_recall: {:.4f}, val_precision: {:.4f}\".format(val_acc, val_recall, val_precision))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: [1/10], Step: [100/1484], Loss: 0.3398\n",
      "Epoch: [1/10], Step: [200/1484], Loss: 0.2120\n",
      "Epoch: [1/10], Step: [300/1484], Loss: 0.2171\n",
      "Epoch: [1/10], Step: [400/1484], Loss: 0.1968\n",
      "Epoch: [1/10], Step: [500/1484], Loss: 0.1906\n",
      "Epoch: [1/10], Step: [600/1484], Loss: 0.1705\n",
      "Epoch: [1/10], Step: [700/1484], Loss: 0.1621\n",
      "Epoch: [1/10], Step: [800/1484], Loss: 0.1481\n",
      "Epoch: [1/10], Step: [900/1484], Loss: 0.1356\n",
      "Epoch: [1/10], Step: [1000/1484], Loss: 0.1217\n",
      "Epoch: [1/10], Step: [1100/1484], Loss: 0.1017\n",
      "Epoch: [1/10], Step: [1200/1484], Loss: 0.0939\n",
      "Epoch: [1/10], Step: [1300/1484], Loss: 0.0758\n",
      "Epoch: [1/10], Step: [1400/1484], Loss: 0.0628\n",
      "train_acc: 0.9856, train_recall: 0.7516, train_precision: 0.9997\n",
      "val_acc: 0.9857, val_recall: 0.7519, val_precision: 0.9997\n",
      "Epoch: [2/10], Step: [100/1484], Loss: 0.0477\n",
      "Epoch: [2/10], Step: [200/1484], Loss: 0.0418\n",
      "Epoch: [2/10], Step: [300/1484], Loss: 0.0408\n",
      "Epoch: [2/10], Step: [400/1484], Loss: 0.0338\n",
      "Epoch: [2/10], Step: [500/1484], Loss: 0.0285\n",
      "Epoch: [2/10], Step: [600/1484], Loss: 0.0240\n",
      "Epoch: [2/10], Step: [700/1484], Loss: 0.0360\n",
      "Epoch: [2/10], Step: [800/1484], Loss: 0.0191\n",
      "Epoch: [2/10], Step: [900/1484], Loss: 0.0201\n",
      "Epoch: [2/10], Step: [1000/1484], Loss: 0.0138\n",
      "Epoch: [2/10], Step: [1100/1484], Loss: 0.0138\n",
      "Epoch: [2/10], Step: [1200/1484], Loss: 0.0136\n",
      "Epoch: [2/10], Step: [1300/1484], Loss: 0.0141\n",
      "Epoch: [2/10], Step: [1400/1484], Loss: 0.0151\n",
      "train_acc: 0.9988, train_recall: 0.9832, train_precision: 0.9998\n",
      "val_acc: 0.9987, val_recall: 0.9801, val_precision: 0.9998\n",
      "Epoch: [3/10], Step: [100/1484], Loss: 0.0120\n",
      "Epoch: [3/10], Step: [200/1484], Loss: 0.0137\n",
      "Epoch: [3/10], Step: [300/1484], Loss: 0.0095\n",
      "Epoch: [3/10], Step: [400/1484], Loss: 0.0071\n",
      "Epoch: [3/10], Step: [500/1484], Loss: 0.0059\n",
      "Epoch: [3/10], Step: [600/1484], Loss: 0.0059\n",
      "Epoch: [3/10], Step: [700/1484], Loss: 0.0063\n",
      "Epoch: [3/10], Step: [800/1484], Loss: 0.0044\n",
      "Epoch: [3/10], Step: [900/1484], Loss: 0.0049\n",
      "Epoch: [3/10], Step: [1000/1484], Loss: 0.0068\n",
      "Epoch: [3/10], Step: [1100/1484], Loss: 0.0039\n",
      "Epoch: [3/10], Step: [1200/1484], Loss: 0.0040\n",
      "Epoch: [3/10], Step: [1300/1484], Loss: 0.0377\n",
      "Epoch: [3/10], Step: [1400/1484], Loss: 0.0314\n",
      "train_acc: 0.9936, train_recall: 0.8879, train_precision: 1.0000\n",
      "val_acc: 0.9938, val_recall: 0.8901, val_precision: 1.0000\n",
      "Epoch: [4/10], Step: [100/1484], Loss: 0.0116\n",
      "Epoch: [4/10], Step: [200/1484], Loss: 0.0151\n",
      "Epoch: [4/10], Step: [300/1484], Loss: 0.0220\n",
      "Epoch: [4/10], Step: [400/1484], Loss: 0.0076\n",
      "Epoch: [4/10], Step: [500/1484], Loss: 0.0062\n",
      "Epoch: [4/10], Step: [600/1484], Loss: 0.0062\n",
      "Epoch: [4/10], Step: [700/1484], Loss: 0.0046\n",
      "Epoch: [4/10], Step: [800/1484], Loss: 0.0078\n",
      "Epoch: [4/10], Step: [900/1484], Loss: 0.0035\n",
      "Epoch: [4/10], Step: [1000/1484], Loss: 0.0115\n",
      "Epoch: [4/10], Step: [1100/1484], Loss: 0.0039\n",
      "Epoch: [4/10], Step: [1200/1484], Loss: 0.0038\n",
      "Epoch: [4/10], Step: [1300/1484], Loss: 0.0028\n",
      "Epoch: [4/10], Step: [1400/1484], Loss: 0.0022\n",
      "train_acc: 0.9973, train_recall: 0.9732, train_precision: 0.9988\n",
      "val_acc: 0.9973, val_recall: 0.9729, val_precision: 0.9988\n",
      "Epoch: [5/10], Step: [100/1484], Loss: 0.0271\n",
      "Epoch: [5/10], Step: [200/1484], Loss: 0.0048\n",
      "Epoch: [5/10], Step: [300/1484], Loss: 0.0098\n",
      "Epoch: [5/10], Step: [400/1484], Loss: 0.0040\n",
      "Epoch: [5/10], Step: [500/1484], Loss: 0.0031\n",
      "Epoch: [5/10], Step: [600/1484], Loss: 0.0029\n",
      "Epoch: [5/10], Step: [700/1484], Loss: 0.0027\n",
      "Epoch: [5/10], Step: [800/1484], Loss: 0.0080\n",
      "Epoch: [5/10], Step: [900/1484], Loss: 0.0021\n",
      "Epoch: [5/10], Step: [1000/1484], Loss: 0.0021\n",
      "Epoch: [5/10], Step: [1100/1484], Loss: 0.0020\n",
      "Epoch: [5/10], Step: [1200/1484], Loss: 0.0020\n",
      "Epoch: [5/10], Step: [1300/1484], Loss: 0.0019\n",
      "Epoch: [5/10], Step: [1400/1484], Loss: 0.0015\n",
      "train_acc: 0.9998, train_recall: 0.9970, train_precision: 1.0000\n",
      "val_acc: 0.9998, val_recall: 0.9962, val_precision: 1.0000\n",
      "Epoch: [6/10], Step: [100/1484], Loss: 0.0037\n",
      "Epoch: [6/10], Step: [200/1484], Loss: 0.0018\n",
      "Epoch: [6/10], Step: [300/1484], Loss: 0.0012\n",
      "Epoch: [6/10], Step: [400/1484], Loss: 0.0015\n",
      "Epoch: [6/10], Step: [500/1484], Loss: 0.0013\n",
      "Epoch: [6/10], Step: [600/1484], Loss: 0.0012\n",
      "Epoch: [6/10], Step: [700/1484], Loss: 0.0016\n",
      "Epoch: [6/10], Step: [800/1484], Loss: 0.0022\n",
      "Epoch: [6/10], Step: [900/1484], Loss: 0.0013\n",
      "Epoch: [6/10], Step: [1000/1484], Loss: 0.0013\n",
      "Epoch: [6/10], Step: [1100/1484], Loss: 0.0039\n",
      "Epoch: [6/10], Step: [1200/1484], Loss: 0.0023\n",
      "Epoch: [6/10], Step: [1300/1484], Loss: 0.0010\n",
      "Epoch: [6/10], Step: [1400/1484], Loss: 0.0010\n",
      "train_acc: 0.9999, train_recall: 0.9993, train_precision: 1.0000\n",
      "val_acc: 0.9999, val_recall: 0.9988, val_precision: 1.0000\n",
      "Epoch: [7/10], Step: [100/1484], Loss: 0.0009\n",
      "Epoch: [7/10], Step: [200/1484], Loss: 0.0007\n",
      "Epoch: [7/10], Step: [300/1484], Loss: 0.0010\n",
      "Epoch: [7/10], Step: [400/1484], Loss: 0.0041\n",
      "Epoch: [7/10], Step: [500/1484], Loss: 0.0012\n",
      "Epoch: [7/10], Step: [600/1484], Loss: 0.0008\n",
      "Epoch: [7/10], Step: [700/1484], Loss: 0.0009\n",
      "Epoch: [7/10], Step: [800/1484], Loss: 0.0008\n",
      "Epoch: [7/10], Step: [900/1484], Loss: 0.0008\n",
      "Epoch: [7/10], Step: [1000/1484], Loss: 0.0007\n",
      "Epoch: [7/10], Step: [1100/1484], Loss: 0.0007\n",
      "Epoch: [7/10], Step: [1200/1484], Loss: 0.0009\n",
      "Epoch: [7/10], Step: [1300/1484], Loss: 0.0008\n",
      "Epoch: [7/10], Step: [1400/1484], Loss: 0.0006\n",
      "train_acc: 1.0000, train_recall: 0.9995, train_precision: 1.0000\n",
      "val_acc: 0.9999, val_recall: 0.9988, val_precision: 1.0000\n",
      "Epoch: [8/10], Step: [100/1484], Loss: 0.0005\n",
      "Epoch: [8/10], Step: [200/1484], Loss: 0.0005\n",
      "Epoch: [8/10], Step: [300/1484], Loss: 0.0007\n",
      "Epoch: [8/10], Step: [400/1484], Loss: 0.0005\n",
      "Epoch: [8/10], Step: [500/1484], Loss: 0.0005\n",
      "Epoch: [8/10], Step: [600/1484], Loss: 0.0005\n",
      "Epoch: [8/10], Step: [700/1484], Loss: 0.0004\n",
      "Epoch: [8/10], Step: [800/1484], Loss: 0.0029\n",
      "Epoch: [8/10], Step: [900/1484], Loss: 0.0007\n",
      "Epoch: [8/10], Step: [1000/1484], Loss: 0.0006\n",
      "Epoch: [8/10], Step: [1100/1484], Loss: 0.0008\n",
      "Epoch: [8/10], Step: [1200/1484], Loss: 0.0005\n",
      "Epoch: [8/10], Step: [1300/1484], Loss: 0.0004\n",
      "Epoch: [8/10], Step: [1400/1484], Loss: 0.0011\n",
      "train_acc: 1.0000, train_recall: 0.9996, train_precision: 1.0000\n",
      "val_acc: 0.9999, val_recall: 0.9988, val_precision: 1.0000\n",
      "Epoch: [9/10], Step: [100/1484], Loss: 0.0004\n",
      "Epoch: [9/10], Step: [200/1484], Loss: 0.0006\n",
      "Epoch: [9/10], Step: [300/1484], Loss: 0.0004\n",
      "Epoch: [9/10], Step: [400/1484], Loss: 0.0020\n",
      "Epoch: [9/10], Step: [500/1484], Loss: 0.0005\n",
      "Epoch: [9/10], Step: [600/1484], Loss: 0.0003\n",
      "Epoch: [9/10], Step: [700/1484], Loss: 0.0004\n",
      "Epoch: [9/10], Step: [800/1484], Loss: 0.0004\n",
      "Epoch: [9/10], Step: [900/1484], Loss: 0.0005\n",
      "Epoch: [9/10], Step: [1000/1484], Loss: 0.0005\n",
      "Epoch: [9/10], Step: [1100/1484], Loss: 0.0004\n",
      "Epoch: [9/10], Step: [1200/1484], Loss: 0.0004\n",
      "Epoch: [9/10], Step: [1300/1484], Loss: 0.0006\n",
      "Epoch: [9/10], Step: [1400/1484], Loss: 0.0003\n",
      "train_acc: 0.9998, train_recall: 0.9999, train_precision: 0.9998\n",
      "val_acc: 0.9998, val_recall: 0.9993, val_precision: 0.9998\n",
      "Epoch: [10/10], Step: [100/1484], Loss: 0.0003\n",
      "Epoch: [10/10], Step: [200/1484], Loss: 0.0006\n",
      "Epoch: [10/10], Step: [300/1484], Loss: 0.0004\n",
      "Epoch: [10/10], Step: [400/1484], Loss: 0.0003\n",
      "Epoch: [10/10], Step: [500/1484], Loss: 0.0003\n",
      "Epoch: [10/10], Step: [600/1484], Loss: 0.0004\n",
      "Epoch: [10/10], Step: [700/1484], Loss: 0.0015\n",
      "Epoch: [10/10], Step: [800/1484], Loss: 0.0004\n",
      "Epoch: [10/10], Step: [900/1484], Loss: 0.0003\n",
      "Epoch: [10/10], Step: [1000/1484], Loss: 0.0006\n",
      "Epoch: [10/10], Step: [1100/1484], Loss: 0.0005\n",
      "Epoch: [10/10], Step: [1200/1484], Loss: 0.0003\n",
      "Epoch: [10/10], Step: [1300/1484], Loss: 0.0003\n",
      "Epoch: [10/10], Step: [1400/1484], Loss: 0.0002\n",
      "train_acc: 1.0000, train_recall: 0.9995, train_precision: 1.0000\n",
      "val_acc: 0.9999, val_recall: 0.9990, val_precision: 1.0000\n"
     ]
    }
   ],
   "source": [
    "# Train the model \n",
    "\n",
    "total_step = len(trainloader)\n",
    "for epoch in range(num_epochs):\n",
    "    running_loss = 0.0\n",
    "    for i, (x, y) in enumerate(trainloader):\n",
    "        # Move tensors to the configured device\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        \n",
    "        #print(\"batch_x: \", type(x), x.size())\n",
    "        #print(\"batch_y: \", type(y), y.size())\n",
    "        #print(x)\n",
    "        #print(y)\n",
    "        \n",
    "        # Forward pass\n",
    "        outputs = model(x)\n",
    "        #print(outputs)\n",
    "        loss = criterion(outputs,y)\n",
    "        \n",
    "        \n",
    "        # Backward and optimize\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        running_loss += loss.item() \n",
    "        if (i+1) % 100 == 0:\n",
    "            print('Epoch: [{}/{}], Step: [{}/{}], Loss: {:.4f}'\n",
    "                  .format(epoch+1, num_epochs, i+1, total_step, running_loss/100))\n",
    "            running_loss = 0.0\n",
    "        lr_scheduler.step()\n",
    "    train_acc, train_recall, train_precision = test(trainloader)\n",
    "    print(\"train_acc: {:.4f}, train_recall: {:.4f}, train_precision: {:.4f}\".format(train_acc, train_recall, train_precision))\n",
    "    val_acc, val_recall, val_precision = test(valloader)\n",
    "    print(\"val_acc: {:.4f}, val_recall: {:.4f}, val_precision: {:.4f}\".format(val_acc, val_recall, val_precision))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
